{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MpkYHwCqk7W-"
      },
      "source": [
        "![MuJoCo banner](https://raw.githubusercontent.com/google-deepmind/mujoco/main/banner.png)\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xBSdkbmGN2K-"
      },
      "source": [
        "### Copyright notice"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_UbO9uhtBSX5"
      },
      "source": [
        "\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eCopyright 2025 DeepMind Technologies Limited.\u003c/small\u003e\u003c/p\u003e\n",
        "\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at \u003ca href=\"http://www.apache.org/licenses/LICENSE-2.0\"\u003ehttp://www.apache.org/licenses/LICENSE-2.0\u003c/a\u003e.\u003c/small\u003e\u003c/small\u003e\u003c/p\u003e\n",
        "\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\u003c/small\u003e\u003c/small\u003e\u003c/p\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dNIJkb_FM2Ux"
      },
      "source": [
        "# Locomotion in The Playground! \u003ca href=\"https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/locomotion.ipynb\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" width=\"140\" align=\"center\"/\u003e\u003c/a\u003e\n",
        "\n",
        "In this notebook, we'll walk through a few locomotion environments available in MuJoCo Playground.\n",
        "\n",
        "**A Colab runtime with GPU acceleration is required.** If you're using a CPU-only runtime, you can switch using the menu \"Runtime \u003e Change runtime type\".\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Xqo7pyX-n72M"
      },
      "outputs": [],
      "source": [
        "#@title Install pre-requisites\n",
        "# `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
        "%pip install mujoco mujoco_mjx brax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "IbZxYDxzoz5R"
      },
      "outputs": [],
      "source": [
        "# @title Check if MuJoCo installation was successful\n",
        "\n",
        "# NOTE(review): `distutils` is not used anywhere in the visible notebook and was\n",
        "# removed from the standard library in Python 3.12 (PEP 632); consider dropping it.\n",
        "import distutils.util\n",
        "import os\n",
        "import subprocess\n",
        "\n",
        "# Fail fast with an actionable message when no NVIDIA GPU is reachable. On a\n",
        "# runtime without the NVIDIA driver the `nvidia-smi` binary may not exist at\n",
        "# all, in which case subprocess.run raises FileNotFoundError instead of\n",
        "# returning a non-zero exit code; treat both cases the same way.\n",
        "try:\n",
        "  gpu_unavailable = subprocess.run('nvidia-smi').returncode != 0\n",
        "except FileNotFoundError:\n",
        "  gpu_unavailable = True\n",
        "if gpu_unavailable:\n",
        "  raise RuntimeError(\n",
        "      'Cannot communicate with GPU. '\n",
        "      'Make sure you are using a GPU Colab runtime. '\n",
        "      'Go to the Runtime menu and select Choose runtime type.'\n",
        "  )\n",
        "\n",
        "# glvnd discovers EGL vendor libraries through ICD JSON files. An Nvidia driver\n",
        "# package normally ships this file, but the Colab kernel does not install its\n",
        "# driver via APT, so the file is missing and has to be written by hand.\n",
        "# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n",
        "NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'\n",
        "_NVIDIA_ICD_CONFIG = \"\"\"{\n",
        "    \"file_format_version\" : \"1.0.0\",\n",
        "    \"ICD\" : {\n",
        "        \"library_path\" : \"libEGL_nvidia.so.0\"\n",
        "    }\n",
        "}\n",
        "\"\"\"\n",
        "if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n",
        "  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as icd_file:\n",
        "    icd_file.write(_NVIDIA_ICD_CONFIG)\n",
        "\n",
        "# Configure MuJoCo to use the EGL rendering backend (requires GPU). The variable\n",
        "# is set here so that it is already in effect when `mujoco` is imported below.\n",
        "print('Setting environment variable to use GPU rendering:')\n",
        "%env MUJOCO_GL=egl\n",
        "\n",
        "try:\n",
        "  print('Checking that the installation succeeded:')\n",
        "  import mujoco\n",
        "\n",
        "  mujoco.MjModel.from_xml_string('\u003cmujoco/\u003e')\n",
        "except Exception as e:\n",
        "  # Raise the friendly message and chain the underlying failure as its cause,\n",
        "  # so the user sees both the hint and the original traceback.\n",
        "  raise RuntimeError(\n",
        "      'Something went wrong during installation. Check the shell output above '\n",
        "      'for more information.\\n'\n",
        "      'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n",
        "      'by going to the Runtime menu and selecting \"Choose runtime type\".'\n",
        "  ) from e\n",
        "\n",
        "print('Installation successful.')\n",
        "\n",
        "# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs.\n",
        "# Only add the flag if it is not already present, so re-running this cell does\n",
        "# not keep appending duplicates (and an explicit user setting is respected).\n",
        "xla_flags = os.environ.get('XLA_FLAGS', '')\n",
        "if '--xla_gpu_triton_gemm_any' not in xla_flags:\n",
        "  xla_flags += ' --xla_gpu_triton_gemm_any=True'\n",
        "os.environ['XLA_FLAGS'] = xla_flags"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "T5f4w3Kq2X14"
      },
      "outputs": [],
      "source": [
        "# @title Import packages for plotting and creating graphics\n",
        "import json\n",
        "import itertools\n",
        "import time\n",
        "from typing import Callable, List, NamedTuple, Optional, Union\n",
        "import numpy as np\n",
        "\n",
        "# Graphics and plotting. mediapy needs an ffmpeg binary to encode videos; it is\n",
        "# only installed via apt when it is not already on the PATH.\n",
        "print(\"Installing mediapy:\")\n",
        "!command -v ffmpeg \u003e/dev/null || (apt update \u0026\u0026 apt install -y ffmpeg)\n",
        "# `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
        "%pip install -q mediapy\n",
        "import mediapy as media\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# More legible printing from numpy.\n",
        "np.set_printoptions(precision=3, suppress=True, linewidth=100)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ObF1UXrkb0Nd"
      },
      "outputs": [],
      "source": [
        "# @title Import MuJoCo, MJX, and Brax\n",
        "from datetime import datetime\n",
        "import functools\n",
        "import os\n",
        "from typing import Any, Dict, Sequence, Tuple, Union\n",
        "from brax import base\n",
        "from brax import envs\n",
        "from brax import math\n",
        "from brax.base import Base, Motion, Transform\n",
        "from brax.base import State as PipelineState\n",
        "from brax.envs.base import Env, PipelineEnv, State\n",
        "from brax.io import html, mjcf, model\n",
        "from brax.mjx.base import State as MjxState\n",
        "from brax.training.agents.ppo import networks as ppo_networks\n",
        "from brax.training.agents.ppo import train as ppo\n",
        "from brax.training.agents.sac import networks as sac_networks\n",
        "from brax.training.agents.sac import train as sac\n",
        "from etils import epath\n",
        "from flax import struct\n",
        "from flax.training import orbax_utils\n",
        "from IPython.display import HTML, clear_output\n",
        "import jax\n",
        "from jax import numpy as jp\n",
        "from matplotlib import pyplot as plt\n",
        "import mediapy as media\n",
        "from ml_collections import config_dict\n",
        "import mujoco\n",
        "from mujoco import mjx\n",
        "import numpy as np\n",
        "from orbax import checkpoint as ocp"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "UoTLSx4cFRdy"
      },
      "outputs": [],
      "source": [
        "#@title Install MuJoCo Playground\n",
        "# `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
        "%pip install playground"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "gYm2h7m8w3Nv"
      },
      "outputs": [],
      "source": [
        "#@title Import The Playground\n",
        "\n",
        "from mujoco_playground import wrapper\n",
        "from mujoco_playground import registry"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LcibXbyKt4FI"
      },
      "source": [
        "# Locomotion\n",
        "\n",
        "MuJoCo Playground contains a host of quadrupedal and bipedal environments. Run the cell below to list all of them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ox0Gze9Ct5AM"
      },
      "outputs": [],
      "source": [
        "# Display the locomotion environments known to the Playground registry.\n",
        "registry.locomotion.ALL_ENVS"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_R01tjWfI-i6"
      },
      "source": [
        "# Quadrupedal\n",
        "\n",
        "Let's jump right into quadrupedal locomotion! While environments are also available for the Google Barkour and Boston Dynamics Spot robots, the Unitree Go1 environments offer the largest number of trainable policies that have been transferred to the real robot, so we'll demonstrate a few policies on the Unitree Go1!\n",
        "\n",
        "First, let's train a joystick policy, which tracks linear and yaw velocity commands."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kPJeoQeEJBSA"
      },
      "outputs": [],
      "source": [
        "# Fetch the default config for the Go1 joystick task, then load the env itself.\n",
        "env_name = 'Go1JoystickFlatTerrain'\n",
        "env_cfg = registry.get_default_config(env_name)\n",
        "env = registry.load(env_name)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6n9UT9N1wR5K"
      },
      "outputs": [],
      "source": [
        "# Inspect the environment's default configuration.\n",
        "env_cfg"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Thm7nZueM4cz"
      },
      "source": [
        "## Joystick\n",
        "\n",
        "Let's train the joystick policy and visualize rollouts:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B9T_UVZYLDdM"
      },
      "outputs": [],
      "source": [
        "from mujoco_playground.config import locomotion_params\n",
        "\n",
        "# PPO hyperparameters that the Playground provides for this environment.\n",
        "ppo_params = locomotion_params.brax_ppo_config(env_name)\n",
        "ppo_params"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Aefr2OS01D9g"
      },
      "source": [
        "Domain randomization is used to make the policy robust enough for sim-to-real transfer. Certain environments in the Playground have domain randomization functions implemented. They're available in the registry and can be passed directly to Brax RL algorithms. The [domain randomization](https://github.com/google-deepmind/mujoco_playground/blob/main/mujoco_playground/_src/locomotion/go1/randomize.py) function randomizes friction, armature, the center of mass of the torso, and link masses, amongst other simulation parameters."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UVA4Bn681DZT"
      },
      "outputs": [],
      "source": [
        "# Show the domain-randomization function registered for this environment.\n",
        "registry.get_domain_randomizer(env_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vBEEQyY6M5OC"
      },
      "source": [
        "### Train\n",
        "\n",
        "Training the policy takes about 7 minutes on an RTX 4090."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XKFzyP7wM5OD"
      },
      "outputs": [],
      "source": [
        "x_data, y_data, y_dataerr = [], [], []\n",
        "times = [datetime.now()]\n",
        "\n",
        "\n",
        "def progress(num_steps, metrics):\n",
        "  \"\"\"Brax progress callback: records the eval reward and redraws the curve.\n",
        "\n",
        "  Also appends a timestamp to `times`; the training cell uses these to report\n",
        "  the time to jit and the time to train.\n",
        "  \"\"\"\n",
        "  # Import explicitly instead of relying on IPython injecting `display`.\n",
        "  from IPython.display import display\n",
        "\n",
        "  clear_output(wait=True)\n",
        "\n",
        "  times.append(datetime.now())\n",
        "  x_data.append(num_steps)\n",
        "  y_data.append(metrics[\"eval/episode_reward\"])\n",
        "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
        "\n",
        "  # Redraw on an empty figure; otherwise every call stacks another complete\n",
        "  # copy of the curve onto the same axes.\n",
        "  plt.clf()\n",
        "  plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n",
        "  plt.xlabel(\"# environment steps\")\n",
        "  plt.ylabel(\"reward per episode\")\n",
        "  plt.title(f\"y={y_data[-1]:.3f}\")\n",
        "  plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
        "\n",
        "  display(plt.gcf())\n",
        "\n",
        "randomizer = registry.get_domain_randomizer(env_name)\n",
        "\n",
        "# Network-architecture settings are not plain keyword arguments of ppo.train:\n",
        "# strip them from the PPO params and bake them into a `network_factory` instead.\n",
        "ppo_training_params = dict(ppo_params)\n",
        "if \"network_factory\" in ppo_params:\n",
        "  del ppo_training_params[\"network_factory\"]\n",
        "  network_factory = functools.partial(\n",
        "      ppo_networks.make_ppo_networks, **ppo_params.network_factory\n",
        "  )\n",
        "else:\n",
        "  network_factory = ppo_networks.make_ppo_networks\n",
        "\n",
        "train_fn = functools.partial(\n",
        "    ppo.train,\n",
        "    **ppo_training_params,\n",
        "    network_factory=network_factory,\n",
        "    randomization_fn=randomizer,\n",
        "    progress_fn=progress,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FGrlulWbM5OD"
      },
      "outputs": [],
      "source": [
        "# Train on `env`, evaluating on a separately loaded copy of the environment.\n",
        "make_inference_fn, params, metrics = train_fn(\n",
        "    environment=env,\n",
        "    eval_env=registry.load(env_name, config=env_cfg),\n",
        "    wrap_env_fn=wrapper.wrap_for_brax_training,\n",
        ")\n",
        "# times[0] is taken before training and times[1] at the first progress\n",
        "# callback; the gap between them is reported as the time to jit.\n",
        "print(f\"time to jit: {times[1] - times[0]}\")\n",
        "print(f\"time to train: {times[-1] - times[1]}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AUxSNhq3UqmC"
      },
      "source": [
        "Let's roll out and render the resulting policy!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RBM89g5A2Yoi"
      },
      "outputs": [],
      "source": [
        "# Build a separate eval env from a fresh default config, with perturbations\n",
        "# enabled and an explicit command range `a`.\n",
        "env_cfg = registry.get_default_config(env_name)\n",
        "env_cfg.pert_config.enable = True\n",
        "env_cfg.pert_config.velocity_kick = [3.0, 6.0]\n",
        "env_cfg.pert_config.kick_wait_times = [5.0, 15.0]\n",
        "env_cfg.command_config.a = [1.5, 0.8, 2*jp.pi]\n",
        "eval_env = registry.load(env_name, config=env_cfg)\n",
        "\n",
        "# Ranges read by `sample_pert` in the next cell. With a zero kick range, the\n",
        "# kick magnitude it writes into `state.info` is always 0, which is why this\n",
        "# line is labelled as disabling the velocity kick.\n",
        "velocity_kick_range = [0.0, 0.0]  # Disable velocity kick.\n",
        "kick_duration_range = [0.05, 0.2]\n",
        "\n",
        "jit_reset, jit_step = jax.jit(eval_env.reset), jax.jit(eval_env.step)\n",
        "jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "C_1CY9xDoUKw"
      },
      "outputs": [],
      "source": [
        "#@title Rollout and Render\n",
        "from mujoco_playground._src.gait import draw_joystick_command\n",
        "\n",
        "x_vel = 0.0  #@param {type: \"number\"}\n",
        "y_vel = 0.0  #@param {type: \"number\"}\n",
        "yaw_vel = 3.14  #@param {type: \"number\"}\n",
        "\n",
        "\n",
        "def sample_pert(rng, state=None):\n",
        "  \"\"\"Samples a kick magnitude and duration and stores them in `state.info`.\n",
        "\n",
        "  The magnitude is drawn uniformly from `velocity_kick_range` and the duration\n",
        "  (in seconds) from `kick_duration_range`; the duration is also rounded to a\n",
        "  whole number of env steps using `eval_env.dt`.\n",
        "\n",
        "  Args:\n",
        "    rng: JAX PRNG key. It is split, and the advanced key is returned.\n",
        "    state: Env state whose `info` dict is updated in place. Defaults to the\n",
        "      notebook-level `state` variable, which existing callers rely on.\n",
        "\n",
        "  Returns:\n",
        "    The advanced PRNG key.\n",
        "  \"\"\"\n",
        "  if state is None:\n",
        "    state = globals()[\"state\"]\n",
        "  rng, key1, key2 = jax.random.split(rng, 3)\n",
        "  pert_mag = jax.random.uniform(\n",
        "      key1, minval=velocity_kick_range[0], maxval=velocity_kick_range[1]\n",
        "  )\n",
        "  duration_seconds = jax.random.uniform(\n",
        "      key2, minval=kick_duration_range[0], maxval=kick_duration_range[1]\n",
        "  )\n",
        "  duration_steps = jp.round(duration_seconds / eval_env.dt).astype(jp.int32)\n",
        "  state.info[\"pert_mag\"] = pert_mag\n",
        "  state.info[\"pert_duration\"] = duration_steps\n",
        "  state.info[\"pert_duration_seconds\"] = duration_seconds\n",
        "  return rng\n",
        "\n",
        "\n",
        "rng = jax.random.PRNGKey(0)\n",
        "rollout = []\n",
        "modify_scene_fns = []\n",
        "\n",
        "# Per-step diagnostics collected for the plots below.\n",
        "swing_peak = []\n",
        "rewards = []\n",
        "linvel = []\n",
        "angvel = []\n",
        "track = []\n",
        "foot_vel = []\n",
        "rews = []\n",
        "contact = []\n",
        "command = jp.array([x_vel, y_vel, yaw_vel])\n",
        "\n",
        "state = jit_reset(rng)\n",
        "if state.info[\"steps_since_last_pert\"] \u003c state.info[\"steps_until_next_pert\"]:\n",
        "  rng = sample_pert(rng)\n",
        "state.info[\"command\"] = command\n",
        "for i in range(env_cfg.episode_length):\n",
        "  if state.info[\"steps_since_last_pert\"] \u003c state.info[\"steps_until_next_pert\"]:\n",
        "    rng = sample_pert(rng)\n",
        "  act_rng, rng = jax.random.split(rng)\n",
        "  ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
        "  state = jit_step(state, ctrl)\n",
        "  # Re-apply the fixed joystick command after every step.\n",
        "  state.info[\"command\"] = command\n",
        "  # Reward terms, keyed by full metric name (`rews`) and with the \"reward/\"\n",
        "  # prefix stripped (`rewards`).\n",
        "  reward_metrics = {\n",
        "      k: v for k, v in state.metrics.items() if k.startswith(\"reward/\")\n",
        "  }\n",
        "  rews.append(reward_metrics)\n",
        "  rewards.append({k[len(\"reward/\"):]: v for k, v in reward_metrics.items()})\n",
        "  rollout.append(state)\n",
        "  swing_peak.append(state.info[\"swing_peak\"])\n",
        "  # Query the same env instance that produced `state`.\n",
        "  linvel.append(eval_env.get_global_linvel(state.data))\n",
        "  angvel.append(eval_env.get_gyro(state.data))\n",
        "  track.append(\n",
        "      eval_env._reward_tracking_lin_vel(\n",
        "          state.info[\"command\"], eval_env.get_local_linvel(state.data)\n",
        "      )\n",
        "  )\n",
        "\n",
        "  feet_vel = state.data.sensordata[eval_env._foot_linvel_sensor_adr]\n",
        "  vel_xy = feet_vel[..., :2]\n",
        "  # NOTE(review): this is the square root of the planar foot speed, not the\n",
        "  # speed itself -- confirm that is intended.\n",
        "  vel_norm = jp.sqrt(jp.linalg.norm(vel_xy, axis=-1))\n",
        "  foot_vel.append(vel_norm)\n",
        "\n",
        "  contact.append(state.info[\"last_contact\"])\n",
        "\n",
        "  # Command-arrow overlay for rendering, placed 0.2 above the torso.\n",
        "  xyz = np.array(state.data.xpos[eval_env._torso_body_id])\n",
        "  xyz += np.array([0, 0, 0.2])\n",
        "  x_axis = state.data.xmat[eval_env._torso_body_id, 0]\n",
        "  yaw = -np.arctan2(x_axis[1], x_axis[0])\n",
        "  modify_scene_fns.append(\n",
        "      functools.partial(\n",
        "          draw_joystick_command,\n",
        "          cmd=state.info[\"command\"],\n",
        "          xyz=xyz,\n",
        "          theta=yaw,\n",
        "          scl=abs(state.info[\"command\"][0])\n",
        "          / env_cfg.command_config.a[0],\n",
        "      )\n",
        "  )\n",
        "\n",
        "\n",
        "# Subsample the rollout for rendering; this fps keeps playback at real time.\n",
        "render_every = 2\n",
        "fps = 1.0 / eval_env.dt / render_every\n",
        "traj, mod_fns = rollout[::render_every], modify_scene_fns[::render_every]\n",
        "\n",
        "scene_option = mujoco.MjvOption()\n",
        "scene_option.geomgroup[2], scene_option.geomgroup[3] = True, False\n",
        "for vis_flag, enabled in (\n",
        "    (mujoco.mjtVisFlag.mjVIS_CONTACTPOINT, True),\n",
        "    (mujoco.mjtVisFlag.mjVIS_TRANSPARENT, False),\n",
        "    (mujoco.mjtVisFlag.mjVIS_PERTFORCE, True),\n",
        "):\n",
        "  scene_option.flags[vis_flag] = enabled\n",
        "\n",
        "frames = eval_env.render(\n",
        "    traj,\n",
        "    camera=\"track\",\n",
        "    width=640,\n",
        "    height=480,\n",
        "    scene_option=scene_option,\n",
        "    modify_scene_fns=mod_fns,\n",
        ")\n",
        "media.show_video(frames, fps=fps, loop=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1QHdoJ2r30En"
      },
      "source": [
        "Let's plot the peak swing height of each foot, and compare the measured (smoothed) linear and angular velocities against the commanded values. This is useful for debugging how well the policy follows the commands!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "gyyynm3ozEet"
      },
      "outputs": [],
      "source": [
        "#@title Plot each foot in a 2x2 grid.\n",
        "\n",
        "# Peak swing height per foot over the rollout; the dashed line marks\n",
        "# `reward_config.max_foot_height`.\n",
        "swing_peak = jp.array(swing_peak)\n",
        "names = [\"FR\", \"FL\", \"RR\", \"RL\"]\n",
        "colors = [\"r\", \"g\", \"b\", \"y\"]\n",
        "max_height = env_cfg.reward_config.max_foot_height\n",
        "fig, axs = plt.subplots(2, 2)\n",
        "for i, (ax, name, color) in enumerate(zip(axs.flat, names, colors)):\n",
        "  ax.plot(swing_peak[:, i], color=color)\n",
        "  ax.set_ylim([0, max_height * 1.25])\n",
        "  ax.axhline(max_height, color=\"k\", linestyle=\"--\")\n",
        "  ax.set(title=name, xlabel=\"time\", ylabel=\"height\")\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "# Measured velocities smoothed with a 10-step moving average, against the\n",
        "# command (dashed red). Each y-range spans +/- the command range a[i].\n",
        "window = jp.ones(10) / 10\n",
        "linvel_x = jp.convolve(jp.array(linvel)[:, 0], window, mode=\"same\")\n",
        "linvel_y = jp.convolve(jp.array(linvel)[:, 1], window, mode=\"same\")\n",
        "angvel_yaw = jp.convolve(jp.array(angvel)[:, 2], window, mode=\"same\")\n",
        "\n",
        "fig, axes = plt.subplots(3, 1, figsize=(10, 10))\n",
        "labels = [\"dx\", \"dy\", \"dyaw\"]\n",
        "for i, (ax, series) in enumerate(zip(axes, (linvel_x, linvel_y, angvel_yaw))):\n",
        "  limit = env_cfg.command_config.a[i]\n",
        "  ax.plot(series)\n",
        "  ax.set_ylim(-limit, limit)\n",
        "  ax.axhline(state.info[\"command\"][i], color=\"red\", linestyle=\"--\")\n",
        "  ax.set_ylabel(labels[i])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t1QAHuYBQBbl"
      },
      "source": [
        "Now let's visualize what happens when we gradually increase the forward linear velocity command."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Q0EuQiVlzh5u"
      },
      "outputs": [],
      "source": [
        "#@title Slowly increase linvel commands\n",
        "\n",
        "rng = jax.random.PRNGKey(0)\n",
        "rollout = []\n",
        "modify_scene_fns = []\n",
        "swing_peak = []\n",
        "linvel = []\n",
        "angvel = []\n",
        "\n",
        "state = jit_reset(rng)\n",
        "for i in range(1_400):\n",
        "  # Step the forward-velocity command up by 0.25 m/s every 200 steps,\n",
        "  # starting from 0 m/s: 0.0, 0.25, ..., 1.5.\n",
        "  if i % 200 == 0:\n",
        "    x = 0.25 * (i // 200)\n",
        "    print(f\"Setting x to {x}\")\n",
        "    command = jp.array([x, 0, 0])\n",
        "  state.info[\"command\"] = command\n",
        "  if state.info[\"steps_since_last_pert\"] \u003c state.info[\"steps_until_next_pert\"]:\n",
        "    rng = sample_pert(rng)\n",
        "  act_rng, rng = jax.random.split(rng)\n",
        "  ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
        "  state = jit_step(state, ctrl)\n",
        "  rollout.append(state)\n",
        "  swing_peak.append(state.info[\"swing_peak\"])\n",
        "  linvel.append(env.get_global_linvel(state.data))\n",
        "  angvel.append(env.get_gyro(state.data))\n",
        "\n",
        "  # Command-arrow overlay for rendering, placed 0.2 above the torso.\n",
        "  xyz = np.array(state.data.xpos[env._torso_body_id])\n",
        "  xyz += np.array([0, 0, 0.2])\n",
        "  x_axis = state.data.xmat[env._torso_body_id, 0]\n",
        "  yaw = -np.arctan2(x_axis[1], x_axis[0])\n",
        "  arrow_fn = functools.partial(\n",
        "      draw_joystick_command,\n",
        "      cmd=command,\n",
        "      xyz=xyz,\n",
        "      theta=yaw,\n",
        "      scl=abs(command[0]) / env_cfg.command_config.a[0],\n",
        "  )\n",
        "  modify_scene_fns.append(arrow_fn)\n",
        "\n",
        "\n",
        "# Plot each foot in a 2x2 grid: peak swing height vs. `max_foot_height`.\n",
        "swing_peak = jp.array(swing_peak)\n",
        "names = [\"FR\", \"FL\", \"RR\", \"RL\"]\n",
        "colors = [\"r\", \"g\", \"b\", \"y\"]\n",
        "max_height = env_cfg.reward_config.max_foot_height\n",
        "fig, axs = plt.subplots(2, 2)\n",
        "for i, (ax, name, color) in enumerate(zip(axs.flat, names, colors)):\n",
        "  ax.plot(swing_peak[:, i], color=color)\n",
        "  ax.set_ylim([0, max_height * 1.25])\n",
        "  ax.axhline(max_height, color=\"k\", linestyle=\"--\")\n",
        "  ax.set(title=name, xlabel=\"time\", ylabel=\"height\")\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "# Velocities smoothed with a 10-step moving average. The dashed red line is the\n",
        "# command stored in the final state, so it shows only one value, not the ramp.\n",
        "window = jp.ones(10) / 10\n",
        "linvel_x = jp.convolve(jp.array(linvel)[:, 0], window, mode=\"same\")\n",
        "linvel_y = jp.convolve(jp.array(linvel)[:, 1], window, mode=\"same\")\n",
        "angvel_yaw = jp.convolve(jp.array(angvel)[:, 2], window, mode=\"same\")\n",
        "\n",
        "fig, axes = plt.subplots(3, 1, figsize=(10, 10))\n",
        "labels = [\"dx\", \"dy\", \"dyaw\"]\n",
        "for i, (ax, series) in enumerate(zip(axes, (linvel_x, linvel_y, angvel_yaw))):\n",
        "  limit = env_cfg.command_config.a[i]\n",
        "  ax.plot(series)\n",
        "  ax.set_ylim(-limit, limit)\n",
        "  ax.axhline(state.info[\"command\"][i], color=\"red\", linestyle=\"--\")\n",
        "  ax.set_ylabel(labels[i])\n",
        "\n",
        "\n",
        "# Subsample the rollout for rendering; this fps keeps playback at real time.\n",
        "render_every = 2\n",
        "fps = 1.0 / eval_env.dt / render_every\n",
        "print(f\"fps: {fps}\")\n",
        "\n",
        "traj, mod_fns = rollout[::render_every], modify_scene_fns[::render_every]\n",
        "# Check there is one scene-modification callback per rendered frame.\n",
        "assert len(traj) == len(mod_fns)\n",
        "\n",
        "scene_option = mujoco.MjvOption()\n",
        "scene_option.geomgroup[2], scene_option.geomgroup[3] = True, False\n",
        "for vis_flag in (\n",
        "    mujoco.mjtVisFlag.mjVIS_CONTACTPOINT,\n",
        "    mujoco.mjtVisFlag.mjVIS_PERTFORCE,\n",
        "):\n",
        "  scene_option.flags[vis_flag] = True\n",
        "\n",
        "frames = eval_env.render(\n",
        "    traj,\n",
        "    camera=\"track\",\n",
        "    width=640,\n",
        "    height=480,\n",
        "    scene_option=scene_option,\n",
        "    modify_scene_fns=mod_fns,\n",
        ")\n",
        "media.show_video(frames, fps=fps, loop=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0RHZvXgmzrEJ"
      },
      "source": [
        "## Handstand\n",
        "\n",
        "Additional policies are available for the Unitree Go1, such as fall-recovery, handstand, and footstand policies. We'll use the handstand policy as an opportunity to demonstrate fine-tuning policies from prior checkpoints. This lets us iterate quickly on training curricula by modifying the environment config between runs.\n",
        "\n",
        "For the Go1 handstand policy, we'll first train with the default configuration, and then add an energy penalty to make the policy smoother and more likely to transfer onto the robot."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RYriZOAxzEk_"
      },
      "outputs": [],
      "source": [
        "from mujoco_playground.config import locomotion_params\n",
        "\n",
        "# Switch to the handstand task: default config, env, and PPO hyperparameters.\n",
        "env_name = 'Go1Handstand'\n",
        "env_cfg = registry.get_default_config(env_name)\n",
        "env = registry.load(env_name)\n",
        "ppo_params = locomotion_params.brax_ppo_config(env_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3nB5ugbdS5kk"
      },
      "source": [
        "Let's create a checkpoint directory and then train a policy with checkpointing."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EyEDpHisS7eO"
      },
      "outputs": [],
      "source": [
        "# Checkpoints go under ./checkpoints/\u003cenv_name\u003e; the env config is written\n",
        "# alongside them as config.json.\n",
        "ckpt_path = epath.Path(\"checkpoints\").resolve() / env_name\n",
        "ckpt_path.mkdir(parents=True, exist_ok=True)\n",
        "print(ckpt_path)\n",
        "\n",
        "config_file = ckpt_path / \"config.json\"\n",
        "with open(config_file, \"w\") as fp:\n",
        "  json.dump(env_cfg.to_dict(), fp, indent=4)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lCRUYofXSNGT"
      },
      "outputs": [],
      "source": [
        "#@title Training fn definition\n",
        "x_data, y_data, y_dataerr = [], [], []\n",
        "times = [datetime.now()]\n",
        "\n",
        "\n",
        "def policy_params_fn(current_step, make_policy, params):\n",
        "  \"\"\"Saves `params` as an Orbax checkpoint at `ckpt_path / str(current_step)`.\"\"\"\n",
        "  del make_policy  # Unused.\n",
        "  ocp.PyTreeCheckpointer().save(\n",
        "      ckpt_path / f\"{current_step}\",\n",
        "      params,\n",
        "      force=True,\n",
        "      save_args=orbax_utils.save_args_from_target(params),\n",
        "  )\n",
        "\n",
        "\n",
        "def progress(num_steps, metrics):\n",
        "  \"\"\"Brax progress callback: records the eval reward and redraws the curve.\n",
        "\n",
        "  Also appends a timestamp to `times`; the training cell uses these to report\n",
        "  the time to jit and the time to train.\n",
        "  \"\"\"\n",
        "  # Import explicitly instead of relying on IPython injecting `display`.\n",
        "  from IPython.display import display\n",
        "\n",
        "  clear_output(wait=True)\n",
        "\n",
        "  times.append(datetime.now())\n",
        "  x_data.append(num_steps)\n",
        "  y_data.append(metrics[\"eval/episode_reward\"])\n",
        "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
        "\n",
        "  # Redraw on an empty figure; otherwise every call stacks another complete\n",
        "  # copy of the curve onto the same axes.\n",
        "  plt.clf()\n",
        "  plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n",
        "  plt.xlabel(\"# environment steps\")\n",
        "  plt.ylabel(\"reward per episode\")\n",
        "  plt.title(f\"y={y_data[-1]:.3f}\")\n",
        "  plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
        "\n",
        "  display(plt.gcf())\n",
        "\n",
        "randomizer = registry.get_domain_randomizer(env_name)\n",
        "\n",
        "# Network-architecture settings are not plain keyword arguments of ppo.train:\n",
        "# strip them from the PPO params and bake them into a `network_factory` instead.\n",
        "ppo_training_params = dict(ppo_params)\n",
        "if \"network_factory\" in ppo_params:\n",
        "  del ppo_training_params[\"network_factory\"]\n",
        "  network_factory = functools.partial(\n",
        "      ppo_networks.make_ppo_networks, **ppo_params.network_factory\n",
        "  )\n",
        "else:\n",
        "  network_factory = ppo_networks.make_ppo_networks\n",
        "\n",
        "# Same as the joystick setup, plus a callback that checkpoints policy params.\n",
        "train_fn = functools.partial(\n",
        "    ppo.train,\n",
        "    **ppo_training_params,\n",
        "    network_factory=network_factory,\n",
        "    randomization_fn=randomizer,\n",
        "    progress_fn=progress,\n",
        "    policy_params_fn=policy_params_fn,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A1oK80x1anPp"
      },
      "source": [
        "The initial policy takes 8 minutes to train on an RTX 4090."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MY6P3abhSNGU"
      },
      "outputs": [],
      "source": [
        "# Reset the progress-curve and timing state so re-running this cell starts\n",
        "# a fresh curve and the timings below refer to this run only.\n",
        "x_data, y_data, y_dataerr = [], [], []\n",
        "times = [datetime.now()]\n",
        "\n",
        "make_inference_fn, params, metrics = train_fn(\n",
        "    environment=registry.load(env_name, config=env_cfg),\n",
        "    eval_env=registry.load(env_name, config=env_cfg),\n",
        "    wrap_env_fn=wrapper.wrap_for_brax_training,\n",
        ")\n",
        "# times[1] is stamped by the first progress callback; the labels below\n",
        "# assume that callback fires once compilation is done.\n",
        "print(f\"time to jit: {times[1] - times[0]}\")\n",
        "print(f\"time to train: {times[-1] - times[1]}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4s6PkZ4GWV4Z"
      },
      "source": [
        "Let's visualize the current policy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "WiWOtc_6WbcX"
      },
      "outputs": [],
      "source": [
        "#@title Rollout and Render\n",
        "inference_fn = make_inference_fn(params, deterministic=True)\n",
        "jit_inference_fn = jax.jit(inference_fn)\n",
        "\n",
        "eval_env = registry.load(env_name, config=env_cfg)\n",
        "jit_reset = jax.jit(eval_env.reset)\n",
        "jit_step = jax.jit(eval_env.step)\n",
        "\n",
        "num_episodes = 10\n",
        "steps_per_episode = env_cfg.episode_length // 2\n",
        "\n",
        "rng = jax.random.PRNGKey(12345)\n",
        "rollout = []\n",
        "# Net mechanical power per step: sum of actuator_force * joint velocity.\n",
        "# qvel[6:] assumes the first 6 velocity DoFs are the floating base.\n",
        "power1 = []\n",
        "# NOTE(review): episodes are not cut short when state.done is set.\n",
        "for _ in range(num_episodes):\n",
        "  rng, reset_rng = jax.random.split(rng)\n",
        "  state = jit_reset(reset_rng)\n",
        "  for _ in range(steps_per_episode):\n",
        "    act_rng, rng = jax.random.split(rng)\n",
        "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
        "    state = jit_step(state, ctrl)\n",
        "    rollout.append(state)\n",
        "    power1.append(jp.sum(state.data.actuator_force * state.data.qvel[6:]))\n",
        "\n",
        "render_every = 2\n",
        "fps = 1.0 / eval_env.dt / render_every\n",
        "traj = rollout[::render_every]\n",
        "\n",
        "scene_option = mujoco.MjvOption()\n",
        "scene_option.geomgroup[2] = True\n",
        "scene_option.geomgroup[3] = False\n",
        "scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n",
        "scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False\n",
        "scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False\n",
        "\n",
        "frames = eval_env.render(\n",
        "    traj, camera=\"side\", scene_option=scene_option, height=480, width=640\n",
        ")\n",
        "media.show_video(frames, fps=fps, loop=False)\n",
        "\n",
        "# Peak of the signed (net) power: steps where joints absorb energy are negative.\n",
        "power = jp.array(power1)\n",
        "print(f\"Max power: {jp.max(power)}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v5p0Z3PPSRik"
      },
      "source": [
        "Notice that the above policy looks jittery and is unlikely to transfer to the real robot. Its max power output is also quite high.\n",
        "\n",
        "The handstand policy that was deployed sim-to-real was trained with a curriculum on `energy_termination_threshold`, `energy` and `dof_acc`, config values that discourage high torques and high power output. Let's finetune the above policy with a lower `energy_termination_threshold` and non-zero `energy` and `dof_acc` reward terms to get a smoother policy."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hrjoVL-_WN-r"
      },
      "source": [
        "### Finetune the previous checkpoint"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jTxAySRSSu96"
      },
      "outputs": [],
      "source": [
        "env_cfg = registry.get_default_config(env_name)\n",
        "env_cfg.energy_termination_threshold = 400  # lower energy termination threshold\n",
        "env_cfg.reward_config.energy = -0.003  # non-zero negative `energy` reward\n",
        "env_cfg.reward_config.dof_acc = -2.5e-7  # non-zero negative `dof_acc` reward\n",
        "\n",
        "FINETUNE_PATH = epath.Path(ckpt_path)\n",
        "# Step directories are named by the integer training step (see\n",
        "# policy_params_fn). Skip anything else, e.g. temporary directories left by\n",
        "# an interrupted save, which would otherwise break int(name).\n",
        "latest_ckpts = [\n",
        "    ckpt for ckpt in FINETUNE_PATH.glob(\"*\")\n",
        "    if ckpt.is_dir() and ckpt.name.isdigit()\n",
        "]\n",
        "if not latest_ckpts:\n",
        "  raise FileNotFoundError(\n",
        "      f\"No step checkpoints found in {FINETUNE_PATH}; run the initial\"\n",
        "      \" training cell first.\"\n",
        "  )\n",
        "latest_ckpts.sort(key=lambda x: int(x.name))\n",
        "latest_ckpt = latest_ckpts[-1]\n",
        "restore_checkpoint_path = latest_ckpt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_M5IqOR6z4bV"
      },
      "outputs": [],
      "source": [
        "# Fresh progress-curve and timing state for the finetuning run.\n",
        "x_data, y_data, y_dataerr = [], [], []\n",
        "times = [datetime.now()]\n",
        "\n",
        "# NOTE(review): train_fn still carries policy_params_fn, so finetuning\n",
        "# checkpoints go to the same `ckpt_path` with force=True. If step numbering\n",
        "# restarts at 0, this overwrites the initial run's checkpoints -- confirm.\n",
        "make_inference_fn, params, metrics = train_fn(\n",
        "    environment=registry.load(env_name, config=env_cfg),\n",
        "    eval_env=registry.load(env_name, config=env_cfg),\n",
        "    wrap_env_fn=wrapper.wrap_for_brax_training,\n",
        "    seed=1,\n",
        "    restore_checkpoint_path=restore_checkpoint_path,  # restore from the checkpoint!\n",
        ")\n",
        "# times[1] comes from the first progress callback.\n",
        "print(f\"time to jit: {times[1] - times[0]}\")\n",
        "print(f\"time to train: {times[-1] - times[1]}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tzG8eY2lz4dk"
      },
      "outputs": [],
      "source": [
        "#@title Rollout and Render Finetune Policy\n",
        "inference_fn = make_inference_fn(params, deterministic=True)\n",
        "jit_inference_fn = jax.jit(inference_fn)\n",
        "\n",
        "eval_env = registry.load(env_name, config=env_cfg)\n",
        "jit_reset = jax.jit(eval_env.reset)\n",
        "jit_step = jax.jit(eval_env.step)\n",
        "\n",
        "num_episodes = 10\n",
        "steps_per_episode = env_cfg.episode_length // 2\n",
        "\n",
        "rng = jax.random.PRNGKey(12345)\n",
        "rollout = []\n",
        "# Net mechanical power per step: sum of actuator_force * joint velocity.\n",
        "# qvel[6:] assumes the first 6 velocity DoFs are the floating base.\n",
        "power1 = []\n",
        "# NOTE(review): episodes are not cut short when state.done is set.\n",
        "for _ in range(num_episodes):\n",
        "  rng, reset_rng = jax.random.split(rng)\n",
        "  state = jit_reset(reset_rng)\n",
        "  for _ in range(steps_per_episode):\n",
        "    act_rng, rng = jax.random.split(rng)\n",
        "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
        "    state = jit_step(state, ctrl)\n",
        "    rollout.append(state)\n",
        "    power1.append(jp.sum(state.data.actuator_force * state.data.qvel[6:]))\n",
        "\n",
        "render_every = 2\n",
        "fps = 1.0 / eval_env.dt / render_every\n",
        "traj = rollout[::render_every]\n",
        "\n",
        "scene_option = mujoco.MjvOption()\n",
        "scene_option.geomgroup[2] = True\n",
        "scene_option.geomgroup[3] = False\n",
        "scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n",
        "scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False\n",
        "scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False\n",
        "\n",
        "frames = eval_env.render(\n",
        "    traj, camera=\"side\", scene_option=scene_option, height=480, width=640\n",
        ")\n",
        "media.show_video(frames, fps=fps, loop=False)\n",
        "\n",
        "# Peak of the signed (net) power: steps where joints absorb energy are negative.\n",
        "power = jp.array(power1)\n",
        "print(f\"Max power: {jp.max(power)}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yCyibqGMiAca"
      },
      "source": [
        "The final policy should exhibit smoother behavior and lower power output! Feel free to finetune the policy further with different reward terms to get the best behavior."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "26o77FfWXvVp"
      },
      "source": [
        "# Bipedal\n",
        "\n",
        "MuJoCo Playground also comes with a host of bipedal environments, such as the Berkeley Humanoid and the Unitree G1/H1. Let's demonstrate a joystick policy on the Berkeley Humanoid. The initial policy takes 17 minutes to train on an RTX 4090."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ESNd18FUanPt"
      },
      "outputs": [],
      "source": [
        "# Berkeley Humanoid joystick task on flat terrain, with its default env\n",
        "# config and the PPO hyperparameters registered for it.\n",
        "env_name = 'BerkeleyHumanoidJoystickFlatTerrain'\n",
        "env_cfg = registry.get_default_config(env_name)\n",
        "ppo_params = locomotion_params.brax_ppo_config(env_name)\n",
        "env = registry.load(env_name)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nibLoRu8anPt"
      },
      "outputs": [],
      "source": [
        "x_data, y_data, y_dataerr = [], [], []\n",
        "times = [datetime.now()]\n",
        "\n",
        "randomizer = registry.get_domain_randomizer(env_name)\n",
        "\n",
        "# A `network_factory` entry in the PPO config holds kwargs for the network\n",
        "# builder; split it out of the ppo.train kwargs and bind it separately.\n",
        "ppo_training_params = dict(ppo_params)\n",
        "if \"network_factory\" in ppo_params:\n",
        "  network_factory = functools.partial(\n",
        "      ppo_networks.make_ppo_networks,\n",
        "      **ppo_training_params.pop(\"network_factory\"),\n",
        "  )\n",
        "else:\n",
        "  network_factory = ppo_networks.make_ppo_networks\n",
        "\n",
        "# Unlike the handstand run, the checkpoint callback (policy_params_fn) is\n",
        "# not attached here.\n",
        "train_fn = functools.partial(\n",
        "    ppo.train,\n",
        "    **ppo_training_params,\n",
        "    network_factory=network_factory,\n",
        "    randomization_fn=randomizer,\n",
        "    progress_fn=progress,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "16dqomv0anPt"
      },
      "outputs": [],
      "source": [
        "# Reset the progress-curve and timing state right before training, so idle\n",
        "# time after the setup cell is not reported as jit time and re-runs start\n",
        "# with a fresh curve.\n",
        "x_data, y_data, y_dataerr = [], [], []\n",
        "times = [datetime.now()]\n",
        "\n",
        "make_inference_fn, params, metrics = train_fn(\n",
        "    environment=env,\n",
        "    eval_env=registry.load(env_name, config=env_cfg),\n",
        "    wrap_env_fn=wrapper.wrap_for_brax_training,\n",
        ")\n",
        "# times[1] is stamped by the first progress callback.\n",
        "print(f\"time to jit: {times[1] - times[0]}\")\n",
        "print(f\"time to train: {times[-1] - times[1]}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sBHDF-JFanPt",
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "#@title Rollout and Render\n",
        "from mujoco_playground._src.gait import draw_joystick_command\n",
        "\n",
        "eval_env = registry.load(env_name)\n",
        "jit_reset = jax.jit(eval_env.reset)\n",
        "jit_step = jax.jit(eval_env.step)\n",
        "jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))\n",
        "\n",
        "rng = jax.random.PRNGKey(1)\n",
        "\n",
        "rollout = []\n",
        "modify_scene_fns = []\n",
        "\n",
        "x_vel = 1.0  #@param {type: \"number\"}\n",
        "y_vel = 0.0  #@param {type: \"number\"}\n",
        "yaw_vel = 0.0  #@param {type: \"number\"}\n",
        "command = jp.array([x_vel, y_vel, yaw_vel])\n",
        "\n",
        "# Gait clock advancing 2*pi*1.5 rad per second (1.5 Hz); the two phases are\n",
        "# half a cycle apart (presumably one per leg).\n",
        "phase_dt = 2 * jp.pi * eval_env.dt * 1.5\n",
        "phase = jp.array([0, jp.pi])\n",
        "\n",
        "# Loop-invariant body lookup, hoisted out of the rollout loop.\n",
        "torso_id = eval_env.mj_model.body(\"torso\").id\n",
        "\n",
        "for j in range(1):\n",
        "  print(f\"episode {j}\")\n",
        "  # Draw a fresh subkey for reset instead of reusing `rng`, which is split\n",
        "  # again below for the action keys.\n",
        "  rng, reset_rng = jax.random.split(rng)\n",
        "  state = jit_reset(reset_rng)\n",
        "  state.info[\"phase_dt\"] = phase_dt\n",
        "  state.info[\"phase\"] = phase\n",
        "  for i in range(env_cfg.episode_length):\n",
        "    act_rng, rng = jax.random.split(rng)\n",
        "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
        "    state = jit_step(state, ctrl)\n",
        "    if state.done:\n",
        "      break\n",
        "    # Pin the command to the user-chosen joystick command.\n",
        "    state.info[\"command\"] = command\n",
        "    rollout.append(state)\n",
        "\n",
        "    xyz = np.array(state.data.xpos[torso_id])\n",
        "    x_axis = state.data.xmat[eval_env._torso_body_id, 0]\n",
        "    yaw = -np.arctan2(x_axis[1], x_axis[0])\n",
        "    modify_scene_fns.append(\n",
        "        functools.partial(\n",
        "            draw_joystick_command,\n",
        "            cmd=state.info[\"command\"],\n",
        "            xyz=xyz,\n",
        "            theta=yaw,\n",
        "            scl=np.linalg.norm(state.info[\"command\"]),\n",
        "        )\n",
        "    )\n",
        "\n",
        "render_every = 1\n",
        "fps = 1.0 / eval_env.dt / render_every\n",
        "print(f\"fps: {fps}\")\n",
        "traj = rollout[::render_every]\n",
        "mod_fns = modify_scene_fns[::render_every]\n",
        "\n",
        "scene_option = mujoco.MjvOption()\n",
        "scene_option.geomgroup[2] = True\n",
        "scene_option.geomgroup[3] = False\n",
        "scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n",
        "scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False\n",
        "scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = False\n",
        "\n",
        "frames = eval_env.render(\n",
        "    traj,\n",
        "    camera=\"track\",\n",
        "    scene_option=scene_option,\n",
        "    width=640*2,\n",
        "    height=480,\n",
        "    modify_scene_fns=mod_fns,\n",
        ")\n",
        "media.show_video(frames, fps=fps, loop=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CBtrAqns35sI"
      },
      "source": [
        "🙌 Hasta la vista!"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true,
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.3"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
